from datasets import load_dataset
from data_prep.TruthyDPOProcessor import TruthyDPOProcessor
from data_prep.PKUProcessor import PKUProcessor
from data_prep.PKUProcessor30K import PKUProcessor30K
from data_prep.UltraFeedbackRDPProcessor import UltraFeedbackRDPProcessor
from data_prep.UltraFeedbackRDPProcessorBinary import UltraFeedbackRDPProcessorBinary
from data_prep.TestProcessor import TestProcessor

def load_and_process_dataset_from_name(dataset_name, split='train', seed=0, removed_dimensions=None):
    """
    Build the preference dataset for a supported dataset name.

    The name is looked up in a fixed registry of processor classes. The
    matching class is instantiated with no arguments and its
    ``get_preference_dataset`` method produces the result.

    Parameters:
        dataset_name (str): Registry key identifying the dataset
            (e.g., 'PKU-Alignment/PKU-SafeRLHF-10K').
        split (str): Forwarded unchanged to the processor
            (e.g., 'train', 'test').
        seed (int): Forwarded unchanged to the processor. Presumably
            controls any randomness there -- confirm against the processor.
        removed_dimensions: Forwarded unchanged to the processor; its
            meaning is defined by the processor, not by this function.

    Returns:
        Whatever the selected processor's ``get_preference_dataset``
        returns.

    Raises:
        ValueError: If ``dataset_name`` is not a key of the registry.
    """
    # Registry: dataset name -> processor class that knows how to prepare it.
    processors_by_name = {
        'jondurbin/truthy-dpo-v0.1': TruthyDPOProcessor,
        'PKU-Alignment/PKU-SafeRLHF-10K': PKUProcessor,
        'PKU-Alignment/PKU-SafeRLHF-30K': PKUProcessor30K,
        # UltraFeedback is routed to the binary variant of its processor.
        'openbmb/UltraFeedback': UltraFeedbackRDPProcessorBinary,
        'testdata': TestProcessor,
        # Add additional datasets and their mappings here
    }

    # Every registry value is a class, so None can only mean "unknown name".
    processor_cls = processors_by_name.get(dataset_name)
    if processor_cls is None:
        raise ValueError(f"Dataset '{dataset_name}' is not recognized.")

    # Processors take no constructor arguments; all options go to the call.
    return processor_cls().get_preference_dataset(split, seed, removed_dimensions)


"""
def main():
    dataset_name = "PKU-Alignment/PKU-SafeRLHF-10K"
    dataset = load_dataset(dataset_name, split="train")
    for i, row in enumerate(dataset.select(range(10))):
        print(f"Row {i + 1}: {row}")

    dataset = load_and_process_dataset_from_name(dataset_name=dataset_name, split='test', seed=0)
    print("--------------------testing--------------------")
    print(dataset)
    for key, values in dataset.items():
        print(f"Keys {key}:")
        for i, row in enumerate(values):
            print(f"Row {i + 1}: {row}")
    
    print("-----------------train----------------------")
    dataset = load_and_process_dataset_from_name(dataset_name=dataset_name, split='train', seed=0)
    print(dataset)
    for key, values in dataset.items():
            print(f"Keys {key}:")
            for i, row in enumerate(values):
                print(f"Row {i + 1}: {row}")
    print("-------------------validation-------------")
    dataset = load_and_process_dataset_from_name(dataset_name=dataset_name, split='validation', seed=0)
    print(dataset)
    for key, values in dataset.items():
        print(f"Keys {key}:")
        for i, row in enumerate(values):
            print(f"Row {i + 1}: {row}")
    

if __name__ == "__main__":
    main()
"""